{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a56c6855-70f7-4738-ae2c-50c9c80e1894",
   "metadata": {},
   "source": [
    "# Lesson 9: Segmentation\n",
    "\n",
    "- In the classroom, the libraries are already installed for you.\n",
    "- If you would like to run this code on your own machine, you can install the following:\n",
    "\n",
    "```\n",
    "    !pip install transformers\n",
    "    !pip install gradio\n",
    "    !pip install timm\n",
    "    !pip install torchvision\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ad498c8-c553-4afd-9ece-373829881db5",
   "metadata": {},
   "source": [
    "- Here is some code that suppresses warning messages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c1cfbb-6550-4639-ac92-339f63ca81de",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "from transformers.utils import logging\n",
    "logging.set_verbosity_error()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efe32afb-1dbf-4a02-9cbe-497ed5c5ac97",
   "metadata": {},
   "source": [
    "### Mask Generation with SAM\n",
    "\n",
    "The [Segment Anything Model (SAM)](https://segment-anything.com) was released by Meta AI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f61871b2-812e-4862-8418-ccbb513a737f",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03a9a3f0-a5c6-430a-acfc-86a77ca95ebd",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "sam_pipe = pipeline(\"mask-generation\",\n",
    "    \"./models/Zigeng/SlimSAM-uniform-77\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f073c4d9",
   "metadata": {},
   "source": [
    "Info about [Zigeng/SlimSAM-uniform-77](https://huggingface.co/Zigeng/SlimSAM-uniform-77)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b8be7ed",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e113a30f-5f10-418c-8775-d0866cc0da64",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "raw_image = Image.open('meta_llamas.jpg')\n",
    "raw_image.resize((720, 375))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9bd34a5-7312-4945-bb8e-19a474f92503",
   "metadata": {},
   "source": [
    "- Running this will take some time.\n",
    "- `points_per_batch` sets how many prompt points the model processes at once: higher values make pipeline inference faster but use more memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "016e2f3f-27cc-4f9a-b9d8-8799280822ee",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "output = sam_pipe(raw_image, points_per_batch=32)"
   ]
  },
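  {
   "cell_type": "markdown",
   "id": "gr-sketch-points-per-batch",
   "metadata": {},
   "source": [
    "A rough sketch of why `points_per_batch` affects speed: the pipeline prompts the model with a grid of points, and those points are fed through in batches. The 32 x 32 grid below is an assumption for illustration, and `num_forward_passes` is a hypothetical helper, not part of the pipeline.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Assumption for illustration: a 32 x 32 grid of prompt points per image.\n",
    "POINTS_PER_SIDE = 32\n",
    "total_points = POINTS_PER_SIDE ** 2\n",
    "\n",
    "def num_forward_passes(total_points, points_per_batch):\n",
    "    # Number of batches needed to cover every prompt point.\n",
    "    return math.ceil(total_points / points_per_batch)\n",
    "\n",
    "for points_per_batch in (16, 32, 64, 128):\n",
    "    passes = num_forward_passes(total_points, points_per_batch)\n",
    "    print(f'points_per_batch={points_per_batch:>3} -> {passes} forward passes')\n",
    "```\n",
    "\n",
    "Fewer, larger batches usually mean less overhead, at the cost of more memory per batch."
   ]
  },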
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd701210-253c-4e6c-a4f8-2a22778c306f",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from helper import show_pipe_masks_on_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e3b40b-c1d4-44b8-ad3a-5a85c77faef2",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "show_pipe_masks_on_image(raw_image, output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d4d8f65-0224-4981-a057-741de8aad6ec",
   "metadata": {},
   "source": [
    "_Note:_ The segmentation colors you get when running this code might differ from the ones you see in the video."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99c0dc35-a5aa-4f85-9fb0-b21c310b9c3c",
   "metadata": {},
   "source": [
    "### Faster Inference: Infer an Image and a Single Point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a7d88bc-e9aa-4e1a-a1c3-96884ced17bd",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from transformers import SamModel, SamProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "391efe26-8706-44eb-9cc0-1573609107eb",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "model = SamModel.from_pretrained(\n",
    "    \"./models/Zigeng/SlimSAM-uniform-77\")\n",
    "\n",
    "processor = SamProcessor.from_pretrained(\n",
    "    \"./models/Zigeng/SlimSAM-uniform-77\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d8bdbb3-150c-4e45-9d8d-a03743f4c990",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "raw_image.resize((720, 375))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34928701-d890-4aee-b2dd-8e1814024d34",
   "metadata": {},
   "source": [
    "- Segment the blue shirt Andrew is wearing.\n",
    "- Give any single 2D point, as `[x, y]` pixel coordinates in the original image, that falls inside that region (the blue shirt)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec7ba89e-e3e3-4c69-8eb8-a6683939f276",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "input_points = [[[1600, 700]]]"
   ]
  },
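  {
   "cell_type": "markdown",
   "id": "gr-sketch-preview-to-original",
   "metadata": {},
   "source": [
    "Because `resize` returns a copy, `raw_image` keeps its original resolution, so the point is given in original-image pixels rather than preview pixels. If you pick a point on the resized preview, a sketch like the following could map it back. `preview_to_original` is a hypothetical helper, and the 2880 x 1500 original size is an assumption for illustration, not read from the file.\n",
    "\n",
    "```python\n",
    "def preview_to_original(point, preview_size, original_size):\n",
    "    # Scale an (x, y) point picked on a resized preview back to original-image pixels.\n",
    "    x, y = point\n",
    "    preview_w, preview_h = preview_size\n",
    "    original_w, original_h = original_size\n",
    "    return [round(x * original_w / preview_w), round(y * original_h / preview_h)]\n",
    "\n",
    "# Hypothetical sizes: a 720 x 375 preview of a 2880 x 1500 original.\n",
    "point = preview_to_original((400, 175), (720, 375), (2880, 1500))\n",
    "\n",
    "# Same nesting as `input_points`: [image][point][x, y].\n",
    "input_points_example = [[point]]\n",
    "print(input_points_example)\n",
    "```"
   ]
  },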
  {
   "cell_type": "markdown",
   "id": "f77ac5b0-cfac-4ba1-888b-076dc6c2db25",
   "metadata": {},
   "source": [
    "- Create the input using the image and the single point.\n",
    "- `return_tensors=\"pt\"` tells the processor to return PyTorch tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8388fb7-52b1-4a8e-b40f-f0e5e474e52b",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "inputs = processor(raw_image,\n",
    "                 input_points=input_points,\n",
    "                 return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ea8e307-3890-4adc-a2a7-fd9ffc14859d",
   "metadata": {},
   "source": [
    "- Given the inputs, get the output from the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f7d328-55f4-4b53-ab10-1d6f6160ea49",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6a807d4-077e-4c19-8e3f-7fac5c561c18",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd87d475-a4f0-4a44-bea7-7decebb2abc3",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "predicted_masks = processor.image_processor.post_process_masks(\n",
    "    outputs.pred_masks,\n",
    "    inputs[\"original_sizes\"],\n",
    "    inputs[\"reshaped_input_sizes\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e7c931e-fe87-4c42-a4ac-08fb19c1e117",
   "metadata": {},
   "source": [
    "- The length of `predicted_masks` corresponds to the number of images used in the input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a988a82-fb2a-400b-9859-e4a13af3f279",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "len(predicted_masks)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99828ca1-7a9f-4438-81f9-d48381e25d31",
   "metadata": {},
   "source": [
    "- Inspect the size of the first (`[0]`) predicted mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a27c7cd-82af-46ac-84ee-a0a7291adc03",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "predicted_mask = predicted_masks[0]\n",
    "predicted_mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b66f617-b45a-42b6-b799-a99f58807fbe",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "outputs.iou_scores"
   ]
  },
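  {
   "cell_type": "markdown",
   "id": "gr-sketch-best-mask",
   "metadata": {},
   "source": [
    "`outputs.iou_scores` holds one predicted quality score per candidate mask. One common way to use the scores is to keep only the highest-scoring mask. The sketch below uses made-up scores in a plain Python list; with the real tensor, `torch.argmax` over the last dimension would play the same role.\n",
    "\n",
    "```python\n",
    "# Hypothetical scores for the three candidate masks of a single point prompt.\n",
    "iou_scores = [0.91, 0.97, 0.88]\n",
    "\n",
    "# Index of the candidate with the highest predicted score.\n",
    "best_index = max(range(len(iou_scores)), key=lambda i: iou_scores[i])\n",
    "print(best_index, iou_scores[best_index])\n",
    "```"
   ]
  },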
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d76624-28c2-4736-82de-174497b8d330",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from helper import show_mask_on_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c561702-3846-4640-93b1-43ca96e2325d",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "for i in range(3):\n",
    "    show_mask_on_image(raw_image, predicted_mask[:, i])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afbeaa3f-72a5-4a15-a82e-cbba2caf2167",
   "metadata": {},
   "source": [
    "## Depth Estimation with DPT\n",
    "\n",
    "- This model was introduced in the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by Ranftl et al. (2021) and first released in [isl-org/DPT](https://github.com/isl-org/DPT)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1429d0-b8b5-4899-b468-6fc3a8c84c54",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "depth_estimator = pipeline(task=\"depth-estimation\",\n",
    "                        model=\"./models/Intel/dpt-hybrid-midas\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a05b882",
   "metadata": {},
   "source": [
    "Info about [Intel/dpt-hybrid-midas](https://huggingface.co/Intel/dpt-hybrid-midas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c12a3c26-771d-4fe9-aee3-c72e83510edd",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "raw_image = Image.open('gradio_tamagochi_vienna.png')\n",
    "raw_image.resize((806, 621))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7b4475e-1146-4903-b6e4-2ca921e9854e",
   "metadata": {},
   "source": [
    "- If you'd like to generate this image or something like it, check out the short course on [Gradio](https://www.deeplearning.ai/short-courses/building-generative-ai-applications-with-gradio/) and go to the lesson \"Image Generation App\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc181af8-77b4-40cd-a59d-13d4d2e82d84",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "output = depth_estimator(raw_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c70859cb-ed27-4961-9e82-0b8c51371c21",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bba8c341-6722-466b-b067-20f5a255061e",
   "metadata": {},
   "source": [
    "- Post-process the output image to resize it to the size of the original image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5127caa0-1cf3-4c91-ab7e-39cbc832c1ef",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "output[\"predicted_depth\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2357b661-ac23-4813-8b28-2801c138fa75",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "output[\"predicted_depth\"].unsqueeze(1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e39fd9-17c6-42fe-b87e-b8db44632e6a",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "prediction = torch.nn.functional.interpolate(\n",
    "    output[\"predicted_depth\"].unsqueeze(1),\n",
    "    size=raw_image.size[::-1],\n",
    "    mode=\"bicubic\",\n",
    "    align_corners=False,\n",
    ")"
   ]
  },
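  {
   "cell_type": "markdown",
   "id": "gr-sketch-size-order",
   "metadata": {},
   "source": [
    "The `[::-1]` in `size=raw_image.size[::-1]` is there because PIL reports an image size as (width, height), while `interpolate` expects the target `size` as (height, width). A small sketch, using a made-up size tuple rather than a real image:\n",
    "\n",
    "```python\n",
    "# Hypothetical PIL-style size: (width, height).\n",
    "pil_size = (806, 621)\n",
    "\n",
    "# Reversing the tuple gives (height, width), the order used for tensor dimensions.\n",
    "target_size = pil_size[::-1]\n",
    "print(target_size)\n",
    "```"
   ]
  },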
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b129d1d-ed14-4cda-a2e0-5256a05769c8",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "prediction.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dac4235-6bb7-42d0-a38b-7893ce0ab72b",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "raw_image.size[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3638ab97-6935-4743-9439-03e2e3b822b3",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3609337d-e61b-4d06-b986-8c472cf1c8e3",
   "metadata": {},
   "source": [
    "- Scale the predicted depth values to the 0 to 255 range so that they can be displayed as an image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ca24bf3-f232-4a46-adfe-6c3d2a627a1e",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51e5d5bc-3fdf-4255-8025-51619843ece4",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "output = prediction.squeeze().numpy()\n",
    "formatted = (output * 255 / np.max(output)).astype(\"uint8\")\n",
    "depth = Image.fromarray(formatted)"
   ]
  },
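  {
   "cell_type": "markdown",
   "id": "gr-sketch-depth-scaling",
   "metadata": {},
   "source": [
    "To see what this scaling does, here is a small sketch on a made-up 2 x 3 array instead of a real depth map. It shows the divide-by-maximum scaling used above next to a min-max variant, which is an alternative for illustration and not what this lesson uses.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical 2 x 3 depth map with arbitrary positive values.\n",
    "depth_values = np.array([[2.0, 4.0, 6.0],\n",
    "                         [8.0, 10.0, 20.0]])\n",
    "\n",
    "# Same scaling as in the lesson: the maximum maps to 255.\n",
    "scaled_by_max = (depth_values * 255 / np.max(depth_values)).astype('uint8')\n",
    "\n",
    "# Min-max alternative: the minimum maps to 0 and the maximum to 255.\n",
    "value_range = depth_values.max() - depth_values.min()\n",
    "scaled_min_max = ((depth_values - depth_values.min()) * 255 / value_range).astype('uint8')\n",
    "\n",
    "print(scaled_by_max)\n",
    "print(scaled_min_max)\n",
    "```\n",
    "\n",
    "Dividing by the maximum keeps the relative proportions of the values, while min-max stretches them to use the full 0 to 255 range."
   ]
  },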
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56452340-3cbc-427f-a12d-98c4c7a62305",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "depth"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d4ada74-05c4-474b-907d-578b5f0d9d37",
   "metadata": {},
   "source": [
    "### Demo using Gradio\n",
    "\n",
    "### Troubleshooting Tip\n",
    "- In the classroom, the code that creates the Gradio app may appear to run indefinitely.\n",
    "  - This is specific to the classroom environment when it is serving many learners at once; you won't experience this issue if you run the code on your own machine.\n",
    "- To fix this, restart the kernel (menu: Kernel -> Restart Kernel) and re-run the code in the lab from the beginning of the lesson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67b04985-b3a4-43c8-94ed-5e8a956863c8",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import gradio as gr\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5939feab-b84b-4691-826c-afff2d477c60",
   "metadata": {
    "height": 285
   },
   "outputs": [],
   "source": [
    "def launch(input_image):\n",
    "    out = depth_estimator(input_image)\n",
    "\n",
    "    # resize the prediction\n",
    "    prediction = torch.nn.functional.interpolate(\n",
    "        out[\"predicted_depth\"].unsqueeze(1),\n",
    "        size=input_image.size[::-1],\n",
    "        mode=\"bicubic\",\n",
    "        align_corners=False,\n",
    "    )\n",
    "\n",
    "    # normalize the prediction\n",
    "    output = prediction.squeeze().numpy()\n",
    "    formatted = (output * 255 / np.max(output)).astype(\"uint8\")\n",
    "    depth = Image.fromarray(formatted)\n",
    "    return depth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d7f891-9265-4afa-ab3a-9b5d2aed942e",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "iface = gr.Interface(launch, \n",
    "                     inputs=gr.Image(type='pil'), \n",
    "                     outputs=gr.Image(type='pil'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56f08a4b-3228-4327-850a-3e2211be83d8",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "iface.launch(share=True, server_port=int(os.environ['PORT1']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81555d70-12ec-4e7c-bb11-c1933935bf6d",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "iface.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "057195e2-858b-4a3e-96bc-8148004128e3",
   "metadata": {},
   "source": [
    "### Close the app\n",
    "- Remember to call `.close()` on the Gradio app when you're done using it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5b878fa",
   "metadata": {},
   "source": [
    "### Try it yourself! \n",
    "- Try this model with your own images!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
